#include <fstream>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <torch/torch.h>

namespace idx = torch::indexing;

// Index helpers used when slicing frames in IrradiationVideoDataset::get_item.
// Python `[:]` — selects an entire dimension.
static const idx::Slice all_select(idx::None, idx::None, idx::None);
// Python `[1:-1]` — drops the first and last element of a dimension.
static const idx::Slice one_one_select(1, -1, idx::None);
// Debug switch; only referenced from commented-out diagnostics in this file.
static bool debug_on = false;

/// Read an entire file into memory as raw bytes.
///
/// @param filename Path of the file to read.
/// @return The file contents byte for byte (empty for an empty file).
/// @throws std::runtime_error if the file cannot be opened or a bulk read
///         comes up short.
std::vector<char> get_the_bytes(const std::string& filename) {
    std::ifstream input(filename, std::ios::binary);
    if (!input) {
        // Fail loudly here; an empty buffer would only surface later as a
        // confusing deserialization error.
        throw std::runtime_error("get_the_bytes: cannot open '" + filename + "'");
    }

    // Fast path for seekable files: size the buffer once and read in bulk
    // instead of growing the vector one byte at a time.
    input.seekg(0, std::ios::end);
    const std::streamoff size = input.tellg();
    if (size >= 0) {
        input.seekg(0, std::ios::beg);
        std::vector<char> bytes(static_cast<std::size_t>(size));
        if (size > 0 && !input.read(bytes.data(), static_cast<std::streamsize>(size))) {
            throw std::runtime_error("get_the_bytes: failed to read '" + filename + "'");
        }
        return bytes;
    }

    // Non-seekable input (e.g. a pipe): the failed seek set failbit, so
    // clear it and stream the bytes instead.
    input.clear();
    return std::vector<char>(std::istreambuf_iterator<char>(input),
                             std::istreambuf_iterator<char>());
}

// Example input files (absolute paths from the original author's machine):
// cv "/home/chonghao/xue_research/supplementary/data/irradiation_v3/cv_all_data.pt"
// ci "/home/chonghao/xue_research/supplementary/data/irradiation_v3/ci_all_data.pt"
// eta "/home/chonghao/xue_research/supplementary/data/irradiation_v3/eta_all_data.pt"
// video "/home/chonghao/xue_research/supplementary/data/irradiation_v3/video_all_data.pt"

/// Load a single tensor from a pickled file.
///
/// The file is read fully into memory and decoded with torch::pickle_load
/// (presumably written from Python via torch.save — TODO confirm format).
/// NOTE(review): the bytes go straight to the pickle decoder; only load
/// files from trusted sources.
///
/// @param filename Path of the file to load.
/// @return The tensor stored in the file.
/// @throws c10::Error if the decoded value is not a tensor (the message
///         names the file and the actual value kind). Errors from reading
///         or decoding the file propagate unchanged.
torch::Tensor get_tensor(const std::string& filename)
{
    const std::vector<char> bytes = get_the_bytes(filename);
    const torch::IValue value = torch::pickle_load(bytes);
    TORCH_CHECK(value.isTensor(),
                "get_tensor: '", filename, "' does not contain a tensor (found ",
                value.tagKind(), ")");
    torch::Tensor tensor = value.toTensor();
    std::cout << "load data from " << filename << ", finish" << std::endl;
    return tensor;
}


/// Input half of one dataset sample, as filled in by
/// IrradiationVideoDataset::get_item. The matching label fields live in
/// ReturnLabel.
struct ReturnData {
    torch::Tensor cv;
    torch::Tensor ci;
    torch::Tensor eta;
    torch::Tensor v;     // video frame (as filled in by get_item)
    int index;           // sample index the data was taken from
    torch::Tensor ul;    // upper_lower entry at `index` (as filled in by get_item)

    // Tensors are taken by value and moved into place: a Tensor is a
    // ref-counted handle, so this avoids an extra refcount round-trip.
    ReturnData(torch::Tensor _cv,
               torch::Tensor _ci,
               torch::Tensor _eta,
               torch::Tensor _v,
               int _index,
               torch::Tensor _ul)
        : cv(std::move(_cv)),
          ci(std::move(_ci)),
          eta(std::move(_eta)),
          v(std::move(_v)),
          index(_index),
          ul(std::move(_ul))
    {
    }
};

/// Label half of one dataset sample, as filled in by
/// IrradiationVideoDataset::get_item from frame `index + skip_step`.
/// The matching input fields live in ReturnData.
struct ReturnLabel {
    torch::Tensor cv_ref;
    torch::Tensor ci_ref;
    torch::Tensor eta_ref;
    torch::Tensor v_ref;     // video frame (as filled in by get_item)
    int index_ref;           // frame index the label was taken from
    torch::Tensor ul_ref;    // upper_lower entry at `index_ref` (as filled in by get_item)

    // Tensors are taken by value and moved into place to avoid an extra
    // refcount round-trip per field.
    ReturnLabel(torch::Tensor _cv_ref,
                torch::Tensor _ci_ref,
                torch::Tensor _eta_ref,
                torch::Tensor _v_ref,
                int _index_ref,
                torch::Tensor _ul_ref)
        : cv_ref(std::move(_cv_ref)),
          ci_ref(std::move(_ci_ref)),
          eta_ref(std::move(_eta_ref)),
          v_ref(std::move(_v_ref)),
          index_ref(_index_ref),
          ul_ref(std::move(_ul_ref))
    {
    }
};

/// One dataset sample: the input data (rd) paired with its label (rl).
struct ReturnItem {
    ReturnData rd;
    ReturnLabel rl;

    // Arguments are taken by value and moved into place, so callers passing
    // temporaries (or std::move'd values) pay no extra copy.
    ReturnItem(ReturnData _rd, ReturnLabel _rl) : rd(std::move(_rd)), rl(std::move(_rl)) {}
};

class IrradiationVideoDataset{
private:
    torch::Tensor cv;
    torch::Tensor ci;
    torch::Tensor eta;
    torch::Tensor video;
    // torch::Tensor start_skip;
    // torch::Tensor skip_step;
    // torch::Tensor cnt;
    int start_skip;
    int skip_step;
    int cnt;
    torch::Tensor upper_lower;

public:
    IrradiationVideoDataset(char* cv_path, char* ci_path, char* eta_path, char* video_path, int _skip_step);

    ReturnItem get_item(size_t index);

    inline int get_len() {
        // if (debug_on) {
        //     std::cout << "cv shape: " << cv.sizes() << std::endl;
        //     std::cout << "ci shape: " << ci.sizes() << std::endl;
        //     std::cout << "eta shape: " << eta.sizes() << std::endl;
        //     std::cout << "video shape: " << video.sizes() << std::endl;
        //     std::cout << "start skip tensor: " << start_skip << std::endl;
        //     std::cout << "skip step tensor: " << skip_step << std::endl;
        //     std::cout << "cnt tensor: " << cnt << std::endl;
        //     // std::cout << "cnt item value: " << cnt.item<int>() << std::endl;
        //     // std::cout << "get len cnt item value: " << cnt.item<int>() << std::endl;
        // }
        return cnt;
    }
};

IrradiationVideoDataset::IrradiationVideoDataset(char* cv_path, char* ci_path, char* eta_path, char* video_path, int _skip_step) {
    // fs::path dir (data_path);
    // fs::path all_data_file (filename_pkl);
    // fs::path all_data_full_path = dir / all_data_file;

    // torch::jit::script::Module all_data = torch::jit::load(data_path);
    // cv = all_data.run_method("get_cv").toTensor();
    // ci = all_data.run_method("get_ci").toTensor();
    // eta = all_data.run_method("get_eta").toTensor();
    // video = all_data.run_method("get_video").toTensor();

    // use this for testing
    // cv = torch::randn({5000, 130, 130});
    // ci = torch::randn({5000, 130, 130});
    // eta = torch::randn({5000, 130, 130});
    // video = torch::randn({5000, 130, 130, 130});

    // use this for training
    cv = get_tensor(cv_path);
    ci = get_tensor(ci_path);
    eta = get_tensor(eta_path);
    video = get_tensor(video_path);

    video = video.to(torch::kFloat);
    video = video - 128.0;
    video = video / 128.0;
    int _start_skip = 9;
    // start_skip = torch::from_blob(&_start_skip, {1}, torch::kInt);
    start_skip = _start_skip;
    // skip_step = torch::from_blob(&_skip_step, {1}, torch::kInt);
    skip_step = _skip_step;
    int _cnt = cv.size(0) - _start_skip * 2 - _skip_step;
    // cnt = torch::from_blob(&_cnt, {1}, torch::kInt);
    cnt = _cnt;

    // if (debug_on) {
    //     std::cout << "cv shape: " << cv.sizes() << std::endl;
    //     std::cout << "ci shape: " << ci.sizes() << std::endl;
    //     std::cout << "eta shape: " << eta.sizes() << std::endl;
    //     std::cout << "video shape: " << video.sizes() << std::endl;
    //     std::cout << "start skip tensor: " << start_skip << std::endl;
    //     std::cout << "skip step tensor: " << skip_step << std::endl;
    //     std::cout << "cnt tensor: " << cnt << std::endl;
    //     // std::cout << "cnt item value: " << cnt.item<int>() << std::endl;
    // }
    
    torch::Tensor prob = torch::ones(cnt + skip_step) * 0.5;
    upper_lower = torch::bernoulli(prob);
}


/// Build sample `index`: input frames at `index`, label frames at
/// `index + skip_step`.
///
/// cv/ci/eta frames are read at an offset of `start_skip` and cropped with
/// `one_one_select` (first and last element dropped) along dims 1 and 2;
/// video frames are read in full. Every frame gets a leading size-1 dim.
///
/// @param index  Sample index in [0, get_len()).
/// @throws c10::IndexError if `index` is out of range.
ReturnItem IrradiationVideoDataset::get_item(size_t index)
{
    // Validate up front. Previously an index >= 2^31 wrapped to a negative
    // int, which torch treats as "count from the end", silently returning the
    // wrong frames. Any index >= cnt already failed later with an IndexError.
    TORCH_CHECK_INDEX(cnt > 0 && index < static_cast<size_t>(cnt),
                      "IrradiationVideoDataset::get_item: index ", index,
                      " is out of range for dataset of length ", cnt);

    const int64_t i = static_cast<int64_t>(index);
    const int64_t i_ref = i + skip_step;

    // Cropped cv/ci/eta frame for dataset position t, with a leading dim.
    auto field_frame = [this](const torch::Tensor& field, int64_t t) {
        return field.index({t + start_skip, one_one_select, one_one_select}).unsqueeze(0);
    };
    // Full video frame for dataset position t, with a leading dim.
    auto video_frame = [this](int64_t t) {
        return video.index({t, all_select, all_select, all_select}).unsqueeze(0);
    };

    // prepare data
    ReturnData return_data(field_frame(cv, i), field_frame(ci, i), field_frame(eta, i),
                           video_frame(i), static_cast<int>(i), upper_lower.index({i}));

    // prepare label
    ReturnLabel return_label(field_frame(cv, i_ref), field_frame(ci, i_ref), field_frame(eta, i_ref),
                             video_frame(i_ref), static_cast<int>(i_ref), upper_lower.index({i_ref}));

    return ReturnItem(return_data, return_label);
}

